// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package flex

// This file contains common test helpers for Autoflex tests.

import (
	"bytes"
	"context"
	"fmt"
	"path/filepath"
	"reflect"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"github.com/hashicorp/terraform-plugin-framework/diag"
	"github.com/hashicorp/terraform-plugin-log/tflogtest"
)

// autoFlexTestCase describes one case driven by runAutoExpandTestCases or
// runAutoFlattenTestCases.
//
//   - Options are forwarded to Expand/Flatten.
//   - Source is the value being converted.
//   - Target is passed to Expand/Flatten as the destination and is later
//     compared with WantTarget. NOTE(review): presumably a pointer, since the
//     conversion writes into it — confirm against Expand/Flatten.
//   - ExpectedDiags is compared with the returned diagnostics when
//     runChecks.CompareDiags is set.
//   - WantTarget is the expected Target after conversion; it is compared only
//     when runChecks.CompareTarget is set and no error diagnostics were returned.
//   - WantDiff, when true, suppresses the "unexpected diff" error in
//     runAutoFlattenTestCases for a Target/WantTarget mismatch.
type autoFlexTestCase struct {
	Options       []AutoFlexOptionsFunc
	Source        any
	Target        any
	ExpectedDiags diag.Diagnostics
	WantTarget    any
	WantDiff      bool
}

// autoFlexTestCases maps subtest names to their test cases.
type autoFlexTestCases map[string]autoFlexTestCase

// runChecks selects which assertions the run*TestCases helpers perform after
// each conversion:
//
//   - CompareDiags: compare the returned diagnostics with the case's
//     ExpectedDiags.
//   - CompareTarget: compare the converted target with the case's expected
//     value; this check is skipped when the conversion returned error
//     diagnostics.
//   - GoldenLogs: decode the captured log output, normalize it with
//     normalizeLogs, and pass it to compareWithGolden along with a path under
//     testdata.
type runChecks struct {
	CompareDiags  bool
	CompareTarget bool
	GoldenLogs    bool // use golden snapshots for log comparison
}

// diagAF builds a one-element diag.Diagnostics for test expectations. The only
// entry is whatever diagFunc returns when handed the reflect.Type of the type
// argument T.
func diagAF[T any](diagFunc func(reflect.Type) diag.ErrorDiagnostic) diag.Diagnostics {
	typ := reflect.TypeFor[T]()

	var diags diag.Diagnostics
	diags = append(diags, diagFunc(typ))
	return diags
}

// diagAFNil builds a one-element diag.Diagnostics whose entry comes from
// calling diagFunc with a nil reflect.Type. Use it when the expected diagnostic
// refers to no concrete type, e.g. a nil source or target.
func diagAFNil(diagFunc func(reflect.Type) diag.ErrorDiagnostic) diag.Diagnostics {
	var noType reflect.Type // deliberately left nil
	return diag.Diagnostics{diagFunc(noType)}
}

// diagAF2 builds a one-element diag.Diagnostics for diagnostic constructors
// that take two types. The entry is diagFunc applied to the reflect.Type of
// T1 and of T2, in that order.
func diagAF2[T1, T2 any](diagFunc func(reflect.Type, reflect.Type) diag.ErrorDiagnostic) diag.Diagnostics {
	first, second := reflect.TypeFor[T1](), reflect.TypeFor[T2]()
	d := diagFunc(first, second)
	return diag.Diagnostics{d}
}

// diagAFTypeErr builds a one-element diag.Diagnostics for diagnostic
// constructors that take a type and an error. The entry is diagFunc applied to
// the reflect.Type of T and to err, which is passed through unchanged.
func diagAFTypeErr[T any](diagFunc func(reflect.Type, error) diag.ErrorDiagnostic, err error) diag.Diagnostics {
	result := diagFunc(reflect.TypeFor[T](), err)
	return diag.Diagnostics{result}
}

// diagAFEmpty returns a non-nil, zero-length diag.Diagnostics. Use it as the
// expectation for test cases that should produce no diagnostics.
func diagAFEmpty() diag.Diagnostics {
	return make(diag.Diagnostics, 0)
}

// setFieldValue assigns value to the field fieldName of the struct that
// structPtr points to, using reflection.
//
// structPtr must be a pointer to a struct. If the struct has no field named
// fieldName, or that field cannot be set (for example because it is
// unexported), the call does nothing. A nil value stores the field type's zero
// value, so pointer, interface, slice and map fields can be cleared. A non-nil
// value whose type is not assignable to the field still panics, which surfaces
// test-setup mistakes.
func setFieldValue(structPtr any, fieldName string, value any) {
	field := reflect.ValueOf(structPtr).Elem().FieldByName(fieldName)
	if !field.IsValid() || !field.CanSet() {
		return
	}

	// reflect.ValueOf(nil) is the invalid zero Value, and Set panics when
	// given it; assign the field's zero value instead.
	if value == nil {
		field.Set(reflect.Zero(field.Type()))
		return
	}

	field.Set(reflect.ValueOf(value))
}

// runAutoExpandTestCases runs every test case as a parallel subtest. Each
// subtest calls Expand(ctx, tc.Source, tc.Target, tc.Options...) with a
// context whose logs are captured in a buffer, then applies the checks enabled
// in checks:
//
//   - CompareDiags: the returned diagnostics must equal tc.ExpectedDiags.
//   - GoldenLogs: the captured log lines are decoded, normalized with
//     normalizeLogs, and passed to compareWithGolden together with a path
//     under testdata derived from the subtest name.
//   - CompareTarget: if Expand returned no error diagnostics, tc.Target must
//     equal tc.WantTarget. The mismatch error is suppressed when tc.WantDiff
//     is set, matching runAutoFlattenTestCases.
//
// opts are extra go-cmp options applied only to the target comparison.
// Existing callers that pass no options keep the previous behavior.
func runAutoExpandTestCases(t *testing.T, testCases autoFlexTestCases, checks runChecks, opts ...cmp.Option) {
	t.Helper()

	for testName, tc := range testCases {
		t.Run(testName, func(t *testing.T) {
			t.Parallel()

			ctx := context.Background()
			var buf bytes.Buffer
			ctx = tflogtest.RootLogger(ctx, &buf)
			ctx = registerTestingLogger(ctx)

			diags := Expand(ctx, tc.Source, tc.Target, tc.Options...)

			if checks.CompareDiags {
				if diff := cmp.Diff(diags, tc.ExpectedDiags); diff != "" {
					t.Errorf("unexpected diagnostics difference: %s", diff)
				}
			}

			if checks.GoldenLogs {
				lines, err := tflogtest.MultilineJSONDecode(&buf)
				if err != nil {
					t.Fatalf("Expand: decoding log lines: %s", err)
				}
				normalizedLines := normalizeLogs(lines)

				goldenFileName := autoGenerateGoldenPath(t, t.Name(), testName)
				goldenPath := filepath.Join("testdata", goldenFileName)
				compareWithGolden(t, goldenPath, normalizedLines)
			}

			if checks.CompareTarget && !diags.HasError() {
				// cmp.Diff(got, want) marks got with "-" and want with "+".
				if diff := cmp.Diff(tc.Target, tc.WantTarget, opts...); diff != "" && !tc.WantDiff {
					t.Errorf("unexpected diff (+wanted, -got): %s", diff)
				}
			}
		})
	}
}

// runAutoFlattenTestCases runs every test case as a parallel subtest. Each
// subtest calls Flatten(ctx, tc.Source, tc.Target, tc.Options...) with a
// context whose logs are captured in a buffer, then applies the checks enabled
// in checks:
//
//   - CompareDiags: the returned diagnostics must equal tc.ExpectedDiags.
//   - GoldenLogs: the captured log lines are decoded, normalized with
//     normalizeLogs, and passed to compareWithGolden together with a path
//     under testdata derived from the subtest name.
//   - CompareTarget: if Flatten returned no error diagnostics, tc.Target must
//     equal tc.WantTarget. Slices are compared order-insensitively, using each
//     element's "%+v" rendering as the sort key. The mismatch error is
//     suppressed when tc.WantDiff is set.
//
// opts are extra go-cmp options applied to the target comparison.
func runAutoFlattenTestCases(t *testing.T, testCases autoFlexTestCases, checks runChecks, opts ...cmp.Option) {
	t.Helper()

	// Build the target-comparison options once, into a slice we own, before
	// any subtest starts. Calling append(opts, ...) inside each parallel
	// subtest writes into the caller's backing array whenever it has spare
	// capacity. Several goroutines would then race on the same slot.
	less := func(a, b any) bool { return fmt.Sprintf("%+v", a) < fmt.Sprintf("%+v", b) }
	cmpOpts := make([]cmp.Option, 0, len(opts)+1)
	cmpOpts = append(cmpOpts, opts...)
	cmpOpts = append(cmpOpts, cmpopts.SortSlices(less))

	for testName, testCase := range testCases {
		t.Run(testName, func(t *testing.T) {
			t.Parallel()

			ctx := context.Background()
			var buf bytes.Buffer
			ctx = tflogtest.RootLogger(ctx, &buf)
			ctx = registerTestingLogger(ctx)

			diags := Flatten(ctx, testCase.Source, testCase.Target, testCase.Options...)

			if checks.CompareDiags {
				if diff := cmp.Diff(diags, testCase.ExpectedDiags); diff != "" {
					t.Errorf("unexpected diagnostics difference: %s", diff)
				}
			}

			if checks.GoldenLogs {
				lines, err := tflogtest.MultilineJSONDecode(&buf)
				if err != nil {
					t.Fatalf("Flatten: decoding log lines: %s", err)
				}
				normalizedLines := normalizeLogs(lines)

				goldenFileName := autoGenerateGoldenPath(t, t.Name(), testName)
				goldenPath := filepath.Join("testdata", goldenFileName)
				compareWithGolden(t, goldenPath, normalizedLines)
			}

			if checks.CompareTarget && !diags.HasError() {
				// cmp.Diff(got, want) marks got with "-" and want with "+".
				if diff := cmp.Diff(testCase.Target, testCase.WantTarget, cmpOpts...); diff != "" && !testCase.WantDiff {
					t.Errorf("unexpected diff (+wanted, -got): %s", diff)
				}
			}
		})
	}
}

// toplevelTestCase is a single case for runTopLevelTestCases.
//
// Unlike autoFlexTestCase, the target is a type parameter rather than an `any`
// value. runTopLevelTestCases declares a fresh zero-valued Ttarget for each
// subtest and passes its address to Flatten. It then compares that value with
// expectedValue using cmp.Diff, so both sides share the same concrete type.
// NOTE(review): the previous comment only said a concrete target type is
// needed "for some reason" when calling cmp.Diff; the exact cause is not
// documented here — confirm before relying on it.
//
//   - source is the value passed to Flatten.
//   - expectedValue is the expected flattened result.
//   - ExpectedDiags is compared with the returned diagnostics when
//     runChecks.CompareDiags is set.
type toplevelTestCase[Tsource, Ttarget any] struct {
	source        Tsource
	expectedValue Ttarget
	ExpectedDiags diag.Diagnostics
}

// toplevelTestCases maps subtest names to their top-level test cases.
type toplevelTestCases[Tsource, Ttarget any] map[string]toplevelTestCase[Tsource, Ttarget]

// runTopLevelTestCases runs each case as a parallel subtest. Every subtest
// flattens tc.source into a freshly declared zero Ttarget, with logs captured
// in a buffer, and then, as selected by checks:
//
//   - CompareDiags: requires the diagnostics to equal tc.ExpectedDiags.
//   - GoldenLogs: hands the normalized log lines to compareWithGolden with a
//     path under testdata derived from the subtest name.
//   - CompareTarget: when Flatten reported no errors, requires the flattened
//     value to equal tc.expectedValue, comparing slices order-insensitively.
func runTopLevelTestCases[Tsource, Ttarget any](t *testing.T, testCases toplevelTestCases[Tsource, Ttarget], checks runChecks) {
	t.Helper()

	// Order-insensitive slice comparison keyed on each element's %+v form.
	sortByFormatted := cmpopts.SortSlices(func(x, y any) bool {
		return fmt.Sprintf("%+v", x) < fmt.Sprintf("%+v", y)
	})

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			var logOutput bytes.Buffer
			ctx := registerTestingLogger(tflogtest.RootLogger(context.Background(), &logOutput))

			var got Ttarget
			diags := Flatten(ctx, tc.source, &got)

			if checks.CompareDiags {
				if d := cmp.Diff(diags, tc.ExpectedDiags); d != "" {
					t.Errorf("unexpected diagnostics difference: %s", d)
				}
			}

			if checks.GoldenLogs {
				entries, err := tflogtest.MultilineJSONDecode(&logOutput)
				if err != nil {
					t.Fatalf("Flatten: decoding log lines: %s", err)
				}
				normalized := normalizeLogs(entries)
				golden := filepath.Join("testdata", autoGenerateGoldenPath(t, t.Name(), name))
				compareWithGolden(t, golden, normalized)
			}

			if !checks.CompareTarget || diags.HasError() {
				return
			}
			if d := cmp.Diff(got, tc.expectedValue, sortByFormatted); d != "" {
				t.Errorf("unexpected diff (+wanted, -got): %s", d)
			}
		})
	}
}
